{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Improved neural networks\n",
    "\n",
    "### 1. Loading the MNIST dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Load the dataset\n",
    "import mnist_loader\n",
    "training_data, validation_data, test_data = mnist_loader.load_data_wrapper()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Classifying MNIST digits with the cross-entropy cost\n",
    "\n",
    "We use a network with 30 hidden neurons and a mini-batch size of 10. We set the learning rate to η = 0.5 and train for 30 epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 training complete\n",
      "Accuracy on evaluation data: 9095 / 10000\n",
      "Epoch 1 training complete\n",
      "Accuracy on evaluation data: 9260 / 10000\n",
      "Epoch 2 training complete\n",
      "Accuracy on evaluation data: 9320 / 10000\n",
      "Epoch 3 training complete\n",
      "Accuracy on evaluation data: 9344 / 10000\n",
      "Epoch 4 training complete\n",
      "Accuracy on evaluation data: 9427 / 10000\n",
      "Epoch 5 training complete\n",
      "Accuracy on evaluation data: 9444 / 10000\n",
      "Epoch 6 training complete\n",
      "Accuracy on evaluation data: 9473 / 10000\n",
      "Epoch 7 training complete\n",
      "Accuracy on evaluation data: 9474 / 10000\n",
      "Epoch 8 training complete\n",
      "Accuracy on evaluation data: 9494 / 10000\n",
      "Epoch 9 training complete\n",
      "Accuracy on evaluation data: 9487 / 10000\n",
      "Epoch 10 training complete\n",
      "Accuracy on evaluation data: 9496 / 10000\n",
      "Epoch 11 training complete\n",
      "Accuracy on evaluation data: 9497 / 10000\n",
      "Epoch 12 training complete\n",
      "Accuracy on evaluation data: 9514 / 10000\n",
      "Epoch 13 training complete\n",
      "Accuracy on evaluation data: 9526 / 10000\n",
      "Epoch 14 training complete\n",
      "Accuracy on evaluation data: 9508 / 10000\n",
      "Epoch 15 training complete\n",
      "Accuracy on evaluation data: 9519 / 10000\n",
      "Epoch 16 training complete\n",
      "Accuracy on evaluation data: 9492 / 10000\n",
      "Epoch 17 training complete\n",
      "Accuracy on evaluation data: 9520 / 10000\n",
      "Epoch 18 training complete\n",
      "Accuracy on evaluation data: 9499 / 10000\n",
      "Epoch 19 training complete\n",
      "Accuracy on evaluation data: 9506 / 10000\n",
      "Epoch 20 training complete\n",
      "Accuracy on evaluation data: 9513 / 10000\n",
      "Epoch 21 training complete\n",
      "Accuracy on evaluation data: 9501 / 10000\n",
      "Epoch 22 training complete\n",
      "Accuracy on evaluation data: 9524 / 10000\n",
      "Epoch 23 training complete\n",
      "Accuracy on evaluation data: 9514 / 10000\n",
      "Epoch 24 training complete\n",
      "Accuracy on evaluation data: 9520 / 10000\n",
      "Epoch 25 training complete\n",
      "Accuracy on evaluation data: 9521 / 10000\n",
      "Epoch 26 training complete\n",
      "Accuracy on evaluation data: 9460 / 10000\n",
      "Epoch 27 training complete\n",
      "Accuracy on evaluation data: 9504 / 10000\n",
      "Epoch 28 training complete\n",
      "Accuracy on evaluation data: 9517 / 10000\n",
      "Epoch 29 training complete\n",
      "Accuracy on evaluation data: 9517 / 10000\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([],\n",
       " [9095,\n",
       "  9260,\n",
       "  9320,\n",
       "  9344,\n",
       "  9427,\n",
       "  9444,\n",
       "  9473,\n",
       "  9474,\n",
       "  9494,\n",
       "  9487,\n",
       "  9496,\n",
       "  9497,\n",
       "  9514,\n",
       "  9526,\n",
       "  9508,\n",
       "  9519,\n",
       "  9492,\n",
       "  9520,\n",
       "  9499,\n",
       "  9506,\n",
       "  9513,\n",
       "  9501,\n",
       "  9524,\n",
       "  9514,\n",
       "  9520,\n",
       "  9521,\n",
       "  9460,\n",
       "  9504,\n",
       "  9517,\n",
       "  9517],\n",
       " [],\n",
       " [])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
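   ],
   "source": [
    "import network2\n",
    "# 784 inputs, 30 hidden neurons, 10 outputs, trained with the cross-entropy cost\n",
    "net = network2.Network([784, 30, 10], cost=network2.CrossEntropyCost)\n",
    "net.large_weight_initializer()\n",
    "# 30 epochs, mini-batch size 10, learning rate 0.5; report test-set accuracy after each epoch\n",
    "net.SGD(training_data, 30, 10, 0.5, evaluation_data=test_data, monitor_evaluation_accuracy=True)"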
   ]
  },
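  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We obtain a network that reaches 95.17% accuracy on the test set (9517 / 10000 after the final epoch).\n",
    "\n",
    "Let's also try 100 hidden neurons, keeping the cross-entropy cost and all other parameters unchanged:"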
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 training complete\n",
      "Accuracy on evaluation data: 9245 / 10000\n",
      "Epoch 1 training complete\n",
      "Accuracy on evaluation data: 9349 / 10000\n",
      "Epoch 2 training complete\n",
      "Accuracy on evaluation data: 9515 / 10000\n",
      "Epoch 3 training complete\n",
      "Accuracy on evaluation data: 9567 / 10000\n",
      "Epoch 4 training complete\n",
      "Accuracy on evaluation data: 9488 / 10000\n",
      "Epoch 5 training complete\n",
      "Accuracy on evaluation data: 9611 / 10000\n",
      "Epoch 6 training complete\n",
      "Accuracy on evaluation data: 9620 / 10000\n",
      "Epoch 7 training complete\n",
      "Accuracy on evaluation data: 9617 / 10000\n",
      "Epoch 8 training complete\n",
      "Accuracy on evaluation data: 9626 / 10000\n",
      "Epoch 9 training complete\n",
      "Accuracy on evaluation data: 9642 / 10000\n",
      "Epoch 10 training complete\n",
      "Accuracy on evaluation data: 9650 / 10000\n",
      "Epoch 11 training complete\n",
      "Accuracy on evaluation data: 9651 / 10000\n",
      "Epoch 12 training complete\n",
      "Accuracy on evaluation data: 9656 / 10000\n",
      "Epoch 13 training complete\n",
      "Accuracy on evaluation data: 9643 / 10000\n",
      "Epoch 14 training complete\n",
      "Accuracy on evaluation data: 9675 / 10000\n",
      "Epoch 15 training complete\n",
      "Accuracy on evaluation data: 9656 / 10000\n",
      "Epoch 16 training complete\n",
      "Accuracy on evaluation data: 9681 / 10000\n",
      "Epoch 17 training complete\n",
      "Accuracy on evaluation data: 9672 / 10000\n",
      "Epoch 18 training complete\n",
      "Accuracy on evaluation data: 9664 / 10000\n",
      "Epoch 19 training complete\n",
      "Accuracy on evaluation data: 9669 / 10000\n",
      "Epoch 20 training complete\n",
      "Accuracy on evaluation data: 9663 / 10000\n",
      "Epoch 21 training complete\n",
      "Accuracy on evaluation data: 9688 / 10000\n",
      "Epoch 22 training complete\n",
      "Accuracy on evaluation data: 9676 / 10000\n",
      "Epoch 23 training complete\n",
      "Accuracy on evaluation data: 9677 / 10000\n",
      "Epoch 24 training complete\n",
      "Accuracy on evaluation data: 9679 / 10000\n",
      "Epoch 25 training complete\n",
      "Accuracy on evaluation data: 9677 / 10000\n",
      "Epoch 26 training complete\n",
      "Accuracy on evaluation data: 9680 / 10000\n",
      "Epoch 27 training complete\n",
      "Accuracy on evaluation data: 9688 / 10000\n",
      "Epoch 28 training complete\n",
      "Accuracy on evaluation data: 9694 / 10000\n",
      "Epoch 29 training complete\n",
      "Accuracy on evaluation data: 9691 / 10000\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([],\n",
       " [9245,\n",
       "  9349,\n",
       "  9515,\n",
       "  9567,\n",
       "  9488,\n",
       "  9611,\n",
       "  9620,\n",
       "  9617,\n",
       "  9626,\n",
       "  9642,\n",
       "  9650,\n",
       "  9651,\n",
       "  9656,\n",
       "  9643,\n",
       "  9675,\n",
       "  9656,\n",
       "  9681,\n",
       "  9672,\n",
       "  9664,\n",
       "  9669,\n",
       "  9663,\n",
       "  9688,\n",
       "  9676,\n",
       "  9677,\n",
       "  9679,\n",
       "  9677,\n",
       "  9680,\n",
       "  9688,\n",
       "  9694,\n",
       "  9691],\n",
       " [],\n",
       " [])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same setup as before, but with 100 hidden neurons\n",
    "net = network2.Network([784, 100, 10], cost=network2.CrossEntropyCost)\n",
    "net.large_weight_initializer()\n",
    "net.SGD(training_data, 30, 10, 0.5, evaluation_data=test_data, monitor_evaluation_accuracy=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Save the trained 100-hidden-neuron network to a JSON file\n",
    "net.save('network2.json')"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [Root]",
   "language": "python",
   "name": "Python [Root]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
